import json
import os
import random

import datasets
import yaml
from PIL import Image

# BibTeX entry for the Playground v2.5 paper; passed to DatasetInfo as the citation.
_CITATION = """\
@misc{li2024playground,
      title={Playground v2.5: Three Insights towards Enhancing Aesthetic Quality in Text-to-Image Generation},
      author={Daiqing Li and Aleks Kamko and Ehsan Akhgari and Ali Sabet and Linmiao Xu and Suhail Doshi},
      year={2024},
      eprint={2402.17245},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}
"""

# Human-readable summary of the benchmark; passed to DatasetInfo as the description.
_DESCRIPTION = """\
We introduce a new benchmark, MJHQ-30K, for automatic evaluation of a model’s aesthetic quality.
The benchmark computes FID on a high-quality dataset to gauge aesthetic quality.
"""

# Dataset landing page; passed to DatasetInfo as the homepage.
_HOMEPAGE = "https://huggingface.co/datasets/playgroundai/MJHQ-30K"

# License name plus a link to its text; passed to DatasetInfo as the license.
_LICENSE = (
    "Playground v2.5 Community License "
    "(https://huggingface.co/playgroundai/playground-v2.5-1024px-aesthetic/blob/main/LICENSE.md)"
)

# Zip archive of images for the "MJHQ" config; the builder downloads and extracts it.
IMAGE_URL = "https://huggingface.co/datasets/playgroundai/MJHQ-30K/resolve/main/mjhq30k_imgs.zip"

# JSON metadata for the "MJHQ" config; the builder reads each entry's "category" and "prompt" from it.
META_URL = "https://huggingface.co/datasets/playgroundai/MJHQ-30K/resolve/main/meta_data.json"

# Zip archive for the "MJHQ-control" config; the builder uses the "MJHQ-5000" directory inside it.
CONTROL_URL = "https://huggingface.co/datasets/mit-han-lab/svdquant-datasets/resolve/main/MJHQ-5000.zip"


class MJHQConfig(datasets.BuilderConfig):
    """Builder configuration carrying two extra options for the MJHQ builder.

    Args:
        max_dataset_size: Stored on the config as ``max_dataset_size``.
            Defaults to ``-1``; the MJHQ builder in this file only subsamples
            when the value is greater than zero.
        return_gt: Stored on the config as ``return_gt``. Defaults to
            ``False``; the MJHQ builder opens the ground-truth image only
            when this is true.
        **kwargs: Only ``name``, ``version``, ``data_dir``, ``data_files`` and
            ``description`` are read and forwarded to
            ``datasets.BuilderConfig``; any other keyword is ignored.
    """

    def __init__(self, max_dataset_size: int = -1, return_gt: bool = False, **kwargs):
        # Fallback used for each base-config field the caller did not supply.
        base_defaults = {
            "name": "default",
            "version": "0.0.0",
            "data_dir": None,
            "data_files": None,
            "description": None,
        }
        base_kwargs = {field: kwargs.get(field, fallback) for field, fallback in base_defaults.items()}
        super().__init__(**base_kwargs)

        self.return_gt = return_gt
        self.max_dataset_size = max_dataset_size


class MJHQ(datasets.GeneratorBasedBuilder):
    """Dataset builder for MJHQ-30K and the MJHQ-5K control subset.

    Two configurations are provided:

    * ``"MJHQ"`` (default): metadata comes from ``META_URL`` (JSON) and images
      from the archive at ``IMAGE_URL``.
    * ``"MJHQ-control"``: prompts come from ``prompts.yaml`` inside the
      ``MJHQ-5000`` directory of the archive at ``CONTROL_URL``; each example
      also carries paths to canny / cropped / depth / mask images.

    Both configurations expose a single ``train`` split with the same schema;
    fields that do not apply to a configuration are ``None``.
    """

    VERSION = datasets.Version("0.0.0")

    BUILDER_CONFIG_CLASS = MJHQConfig
    BUILDER_CONFIGS = [
        MJHQConfig(name="MJHQ", version=VERSION, description="MJHQ-30K full dataset"),
        MJHQConfig(name="MJHQ-control", version=VERSION, description="MJHQ-5K with controls"),
    ]
    DEFAULT_CONFIG_NAME = "MJHQ"

    def _info(self):
        """Return the dataset metadata and the feature schema shared by both configs."""
        features = datasets.Features(
            {
                "filename": datasets.Value("string"),
                "category": datasets.Value("string"),
                "image": datasets.Image(),
                "prompt": datasets.Value("string"),
                "prompt_path": datasets.Value("string"),
                "image_root": datasets.Value("string"),
                "image_path": datasets.Value("string"),
                "split": datasets.Value("string"),
                "canny_image_path": datasets.Value("string"),
                "cropped_image_path": datasets.Value("string"),
                "depth_image_path": datasets.Value("string"),
                "mask_image_path": datasets.Value("string"),
            }
        )
        return datasets.DatasetInfo(
            description=_DESCRIPTION, features=features, homepage=_HOMEPAGE, license=_LICENSE, citation=_CITATION
        )

    def _split_generators(self, dl_manager: datasets.download.DownloadManager):
        """Download the files for the active config and describe its single train split.

        Args:
            dl_manager: Download manager used to fetch (and extract) the remote files.

        Returns:
            A one-element list with the ``train`` split generator, whose
            ``gen_kwargs`` supply ``meta_path`` and ``image_root`` to
            ``_generate_examples``.
        """
        if self.config.name == "MJHQ":
            meta_path = dl_manager.download(META_URL)
            image_root = dl_manager.download_and_extract(IMAGE_URL)
            return [
                datasets.SplitGenerator(
                    name=datasets.Split.TRAIN, gen_kwargs={"meta_path": meta_path, "image_root": image_root}
                ),
            ]
        else:
            assert self.config.name == "MJHQ-control"
            control_root = dl_manager.download_and_extract(CONTROL_URL)
            # The archive's contents are read from its "MJHQ-5000" directory.
            control_root = os.path.join(control_root, "MJHQ-5000")
            return [
                datasets.SplitGenerator(
                    name=datasets.Split.TRAIN,
                    gen_kwargs={"meta_path": os.path.join(control_root, "prompts.yaml"), "image_root": control_root},
                ),
            ]

    @staticmethod
    def _subsample_names(names, max_dataset_size):
        """Return ``names``, reduced to a fixed-seed subset when a positive limit is set.

        With ``max_dataset_size <= 0`` the names are returned in their original
        order. Otherwise a copy is shuffled with ``random.Random(0)``, truncated
        to ``max_dataset_size`` entries and sorted, so the same subset is picked
        on every run.
        """
        if max_dataset_size <= 0:
            return names
        shuffled = list(names)
        random.Random(0).shuffle(shuffled)
        return sorted(shuffled[:max_dataset_size])

    def _generate_examples(self, meta_path: str, image_root: str):
        """Yield ``(index, example)`` pairs for the active configuration.

        Args:
            meta_path: Path to the metadata file: JSON for ``"MJHQ"``, YAML for
                ``"MJHQ-control"``.
            image_root: Directory under which the image files are located.

        Yields:
            Tuples of a running integer key and a dict matching the schema
            declared in ``_info``. For ``"MJHQ"`` the ``image`` field is an
            opened PIL image only when ``config.return_gt`` is true.
        """
        if self.config.name == "MJHQ":
            # Explicit UTF-8 so decoding does not depend on the platform's default encoding.
            with open(meta_path, "r", encoding="utf-8") as f:
                meta = json.load(f)

            names = self._subsample_names(list(meta.keys()), self.config.max_dataset_size)

            for i, name in enumerate(names):
                category = meta[name]["category"]
                prompt = meta[name]["prompt"]
                image_path = os.path.join(image_root, category, f"{name}.jpg")
                yield i, {
                    "filename": name,
                    "category": category,
                    "image": Image.open(image_path) if self.config.return_gt else None,
                    "prompt": prompt,
                    # The schema declares "prompt_path"; emitting the metadata path under
                    # that key populates the column (it used to be yielded as "meta_path",
                    # a key the schema does not declare).
                    "prompt_path": meta_path,
                    "image_root": image_root,
                    "image_path": image_path,
                    "split": self.config.name,
                    "canny_image_path": None,
                    "cropped_image_path": None,
                    "depth_image_path": None,
                    "mask_image_path": None,
                }
        else:
            assert self.config.name == "MJHQ-control"
            # Context manager closes the file handle once the YAML has been parsed.
            with open(meta_path, "r", encoding="utf-8") as f:
                meta = yaml.safe_load(f)

            names = self._subsample_names(list(meta.keys()), self.config.max_dataset_size)

            for i, name in enumerate(names):
                prompt = meta[name]
                yield i, {
                    "filename": name,
                    "category": None,
                    "image": None,
                    "prompt": prompt,
                    # Same key as the declared "prompt_path" feature (see the "MJHQ" branch).
                    "prompt_path": meta_path,
                    "image_root": image_root,
                    "image_path": os.path.join(image_root, "images", f"{name}.png"),
                    "split": self.config.name,
                    "canny_image_path": os.path.join(image_root, "canny_images", f"{name}.png"),
                    "cropped_image_path": os.path.join(image_root, "cropped_images", f"{name}.png"),
                    "depth_image_path": os.path.join(image_root, "depth_images", f"{name}.png"),
                    "mask_image_path": os.path.join(image_root, "mask_images", f"{name}.png"),
                }
